#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
Created on 2016年12月9日

@author: yangzhou1
线性单元
'''
from neurals.Perceptron import Perceptron

class LinearUnit(Perceptron):
    '''
    Linear unit: a Perceptron configured with a linear (identity) function
    as its activation function.
    '''
    def __init__(self, input_num):
        '''
        Initialise the unit by delegating to Perceptron.__init__.

        input_num is forwarded unchanged; the caller in this file passes the
        number of input features. The second argument is an identity
        function (returns its argument untouched), presumably used by
        Perceptron as the activation function -- confirm against
        neurals.Perceptron.
        '''
        # Explicit base-class call kept as in the original.
        Perceptron.__init__(self, input_num, lambda x: x)

def train_linear_unit():
    '''
    Build a LinearUnit and train it on a small built-in data set.

    Returns the LinearUnit instance after train() has been called on it.
    '''
    # Each sample pairs years of work experience with a monthly salary.
    samples = [(5, 5500), (3, 2300), (8, 7600), (1.4, 1800), (10.1, 11400)]
    # One single-element feature vector per sample ...
    feature_rows = [[years] for years, _ in samples]
    # ... and the matching label for each of them, in the same order.
    salaries = [salary for _, salary in samples]
    # Only one input feature is used: years of work experience.
    unit = LinearUnit(1)
    # Train for 10 iterations with a learning rate of 0.01.
    unit.train(feature_rows, salaries, 10, 0.01)
    return unit
if __name__=='__main__':
    # Script entry point: train the unit, print it, then print predictions.
    linear_unit=train_linear_unit()
    print (linear_unit)
    # Each entry pairs an output template with the feature vector to predict
    # on; the entries are processed in the order listed.
    queries = (
        ('Work 3.4 years,monthly salary = %.2f', [3.4]),
        ('Work 15 years,monthly salary = %.2f', [15]),
        ('Work 1.5 years,monthly salary = %.2f', [1.5]),
        ('Work 6.3 years,monthly salary = %.2f', [6.3]),
    )
    for template, features in queries:
        print (template % linear_unit.predict(features))